# ar_sim/common/kernel_builder.py

import numpy as np

def build_reproduction_kernel(n_vals: np.ndarray,
                              pivot_params: dict,
                              sigma: float = 1.0) -> np.ndarray:
    """
    Construct M_ij = g(D_i) · exp[-(n_i - n_j)^2 / (2 σ²)] · g(D_j),
    with g(D) = a·D + b (pivot function).

    Parameters
    ----------
    n_vals : array-like, 1-D
        Context indices (same order as your field vector). Non-floating
        inputs (including unsigned integers) are promoted to float before
        differencing so that ``n_i - n_j`` cannot wrap around.
    pivot_params : dict
        Must contain keys "a", "b", and "D_vals" (array-like). The pivot
        weights ``a * D_vals + b`` must broadcast to the shape of ``n_vals``
        (normally one D value per context index).
    sigma : float
        Width of the Gaussian kernel; must be strictly positive.

    Returns
    -------
    M : 2-D ndarray (N × N)
        Reproduction kernel matrix; exactly symmetric (``M == M.T``).

    Raises
    ------
    KeyError
        If "a", "b" or "D_vals" is missing from ``pivot_params``.
    ValueError
        If ``sigma`` is not strictly positive, ``n_vals`` is not 1-D, or the
        pivot weights cannot be matched to ``n_vals``.
    """
    missing = [key for key in ("a", "b", "D_vals") if key not in pivot_params]
    if missing:
        raise KeyError(f"pivot_params is missing required key(s): {missing}")

    # `not sigma > 0` also rejects NaN; sigma == 0 would otherwise produce
    # 0/0 = NaN on the diagonal of the Gaussian.
    if not sigma > 0:
        raise ValueError(f"sigma must be strictly positive, got {sigma!r}")

    n = np.asarray(n_vals)
    if n.ndim != 1:
        raise ValueError(f"n_vals must be 1-D, got shape {n.shape}")
    # Unsigned ints wrap on subtraction (uint8: 0 - 1 -> 255); promote any
    # non-floating dtype to float while keeping existing float precision.
    if not np.issubdtype(n.dtype, np.inexact):
        n = n.astype(float)

    # Compute pivot weights g(D) = a*D + b
    a = pivot_params["a"]
    b = pivot_params["b"]
    D_vals = np.asarray(pivot_params["D_vals"])
    g = np.asarray(a * D_vals + b)
    # Require one weight per context index (scalars / length-1 broadcast);
    # a mismatch would otherwise silently yield a wrongly-sized matrix.
    try:
        g = np.broadcast_to(g, n.shape)
    except ValueError as err:
        raise ValueError(
            f"pivot weights (shape {g.shape}) do not match n_vals "
            f"(shape {n.shape})"
        ) from err

    # Pairwise squared differences (n_i - n_j)^2 and the Gaussian kernel
    diff = n[:, None] - n[None, :]
    G = np.exp(-(diff ** 2) / (2.0 * sigma ** 2))

    # outer(g, g)[i, j] = g_i * g_j is bitwise equal to g_j * g_i, and G is
    # symmetric, so M is exactly symmetric (unlike (g_i * G_ij) * g_j).
    M = np.outer(g, g) * G
    return M


def compute_kernel_eigenvalues(M: np.ndarray, k: int = 5) -> "tuple[np.ndarray, np.ndarray]":
    """
    Compute the top-k eigenvalues and eigenvectors of the symmetric kernel matrix M.

    Only the requested largest eigenpairs are computed (via
    ``scipy.linalg.eigh(..., subset_by_index=...)``), not the full spectrum.

    Parameters
    ----------
    M : 2-D ndarray
        Symmetric reproduction kernel matrix (N × N). Only its lower
        triangle is read by ``eigh``.
    k : int
        Number of largest eigenvalues/vectors to return. Values larger than
        N are clamped to N; ``k == 0`` returns empty results.

    Returns
    -------
    vals : ndarray, shape (min(k, N),)
        Top eigenvalues in descending order.
    vecs : ndarray, shape (N, min(k, N))
        Corresponding eigenvectors as columns, in the same order as ``vals``.

    Raises
    ------
    ValueError
        If ``M`` is not a square 2-D array or ``k`` is negative.
    """
    from scipy.linalg import eigh

    M = np.asarray(M)
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError(f"M must be a square 2-D matrix, got shape {M.shape}")
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k!r}")

    n = M.shape[0]
    k = min(k, n)
    # Previously vals[-0:] returned the whole spectrum for k == 0.
    if k == 0:
        return np.empty(0, dtype=float), np.empty((n, 0), dtype=float)

    # eigh returns the selected eigenpairs in ascending order; flip to descending.
    vals, vecs = eigh(M, subset_by_index=[n - k, n - 1])
    return vals[::-1], vecs[:, ::-1]

